package com.hcf.streaming

import java.security.MessageDigest
import java.text.SimpleDateFormat
import java.util.Date

import com.google.gson.{JsonObject, JsonParser}
import org.apache.commons.lang.StringUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SaveMode, SparkSession}

import scala.collection.mutable


/**
 * Strong id linking ("id mapping") job.
 *
 *  1. Scatter: every record (a row of the existing id_mapping table or a line of the new csv
 *     file) becomes a set of `column__value` tokens and is emitted once per token.
 *  2. Link: records sharing any token are merged into one id set,
 *     e.g. {id1(1,2,3), id2(4,5,6), id3(7,8,9)}.
 *  3. Expand: when a merged user has several values for one column, the user is written as
 *     several rows (cartesian product over the columns).
 */
object Id_mapping {
  /** Timestamp format used to suffix the temporary (`_new_`) and backup (`_bak_`) table names. */
  val sdf = new SimpleDateFormat("yyyyMMddHHmmss")

  // Separator between the column name and the value inside an id token ("column__value").
  private val Separator = "__"

  /**
   * Entry point.
   *
   * `args(0)` is a JSON object with the keys `input_csv`, `csv_head`, `table_name` and
   * `mapping_colume`, for example:
   * {{{
   * {'input_csv':'C:/Users/hyt/Desktop/national_tax.csv','csv_head':'nsrsbh, identity','table_name':'data_bridge.id_mapping','mapping_colume':'identity=identity'}
   * }}}
   * `csv_head` lists the csv columns in file order. Each `mapping_colume` entry is written as
   * `tableColumn=csvColumn` and renames that csv column to the table column; csv columns
   * without an entry are appended to the table as new columns.
   *
   * The merged result is saved as `<table_name>_new_<timestamp>`, the current table is renamed
   * to `<table_name>_bak_<timestamp>` and the new table then takes the original name.
   *
   * @param args command line arguments; only `args(0)` is read
   * @throws IllegalArgumentException if the argument is missing, lacks a required key or
   *                                  contains a malformed `mapping_colume` entry
   */
  def main(args: Array[String]): Unit = {
    require(args.nonEmpty,
      "expected one argument: a JSON object with input_csv, csv_head, table_name and mapping_colume")

    // All spaces are removed up front so that lists such as 'nsrsbh, identity' split cleanly.
    // NOTE(review): this also removes spaces inside input_csv paths.
    val obj = new JsonParser().parse(args(0).replace(" ", "")).getAsJsonObject

    def required(key: String): String =
      Option(obj.get(key)).map(_.getAsString).getOrElse(
        throw new IllegalArgumentException(s"missing '$key' in job argument"))

    val input_csv = required("input_csv")
    val table_name = required("table_name")
    val csv_head = required("csv_head").split(",")
    val mapping_colume = required("mapping_colume").split(",")

    // NOTE(review): the master is hard-coded to local mode; confirm this is intended for
    // cluster deployments.
    val sparkSession = SparkSession.builder()
      .appName("id_mapping")
      .master("local[*]")
      .enableHiveSupport() // enable Hive support
      .getOrCreate()

    // Mapping rules: csv column name -> id_mapping column name.
    val csvToTableColumn: Map[String, String] = mapping_colume.map { rule =>
      rule.split("=") match {
        case Array(tableColumn, csvColumn, _*) => csvColumn -> tableColumn
        case _ =>
          throw new IllegalArgumentException(
            s"mapping_colume entry must look like tableColumn=csvColumn: '$rule'")
      }
    }.toMap

    val dataFrame = sparkSession.sql(s"select * from $table_name")
    val field = dataFrame.schema.fields.map(_.name)
    // Output columns: the table columns followed by the csv columns that are not renamed.
    val all_colume = (field ++ csv_head.filterNot(csvToTableColumn.contains)).distinct

    // Existing id_mapping table: one token set per row. Null and blank cells are skipped;
    // non-string cells are converted with toString instead of failing the cast.
    val table_rdd = dataFrame.rdd.map { row =>
      val tokens = mutable.Set.empty[String]
      for (i <- field.indices if !row.isNullAt(i)) {
        val value = row.get(i).toString
        if (StringUtils.isNotBlank(value)) {
          tokens.add(field(i) + Separator + value)
        }
      }
      tokens
    }

    // New csv file: one token set per line. Zipping header and fields tolerates lines with
    // fewer fields than csv_head (String.split drops trailing empty fields); extra fields are
    // ignored.
    val csv_rdd = sparkSession.sparkContext.textFile(input_csv).map { line =>
      val tokens = mutable.Set.empty[String]
      for ((csvColumn, value) <- csv_head.zip(line.split(",")) if StringUtils.isNotBlank(value)) {
        tokens.add(csvToTableColumn.getOrElse(csvColumn, csvColumn) + Separator + value)
      }
      tokens
    }

    // MR1 scatter: key = token, value = (union of all record sets containing the token,
    // empty key set, number of records containing the token).
    // NOTE(review): the reducer mutates its first argument, and one record's set instance is
    // shared by all keys emitted for that record; statement order is kept as-is on purpose.
    val rdd1: RDD[(String, (mutable.Set[String], mutable.Set[String], Int))] = table_rdd.++(csv_rdd)
      .flatMap { set =>
        set.map(t => (t, (set, 1)))
      }.reduceByKey { (t1, t2) =>
        t1._1 ++= t2._1
        val added = t1._2 + t2._2
        (t1._1, added)
      }.map { t =>
        (t._1, (t._2._1, mutable.Set.empty[String], t._2._2))
      }

    // MR2 strong id linking: exactly two propagation rounds.
    // NOTE(review): presumably chains needing more than two rounds are not fully merged — confirm.
    val rdd2: RDD[(String, (mutable.Set[String], mutable.Set[String], Int))] = rdd1
      .flatMap(flatIdSet).reduceByKey(tuple3Add)
    val rdd3: RDD[(String, (mutable.Set[String], mutable.Set[String], Int))] = rdd2
      .flatMap(flatIdSet).reduceByKey(tuple3Add)

    // De-duplicate: the entry's own key is added to its key set; an entry is kept when its
    // count is 1 or when every id of its id set is in the key set. Only the distinct id sets
    // survive.
    val rdd4 = rdd3.filter { t =>
      t._2._2 += t._1
      t._2._3 == 1 || (t._2._1 -- t._2._2).isEmpty
    }.map(_._2._1).distinct()

    // MR3 duplicate handling: several values of one column for the same user become several rows.
    val rdd5 = rdd4.flatMap { set =>
      // Match on "column__" rather than the bare column name, so a column whose name is a
      // prefix of another column's name does not pick up that column's tokens.
      val perColumn = mutable.ListBuffer(
        all_colume.map(column => set.filter(_.startsWith(column + Separator))): _*)
      cartesion_product(perColumn).map { list =>
        Row.fromSeq(list.zip(all_colume).map { case (token, column) =>
          stripColumnPrefix(token, column)
        })
      }
    }

    // Save to Hive, then swap the new table in and keep the old one as a backup.
    val schema = StructType(all_colume.map(col_name => StructField(col_name, StringType, nullable = true)))
    val result = sparkSession.createDataFrame(rdd5, schema)
    val stamp = sdf.format(new Date())
    val new_table = s"${table_name}_new_$stamp"
    result.write.mode(SaveMode.Overwrite).saveAsTable(new_table)
    sparkSession.sql(s"alter table $table_name rename to ${table_name}_bak_$stamp")
    sparkSession.sql(s"alter table $new_table rename to $table_name")
  }

  /**
   * Turns a product cell back into a plain column value.
   *
   * A `column__value` token yields `value`; a blank cell yields `""`; anything else (the MD5
   * id written into the first cell by [[cartesion_product]]) is returned unchanged, so the id
   * keeps all of its hex characters.
   */
  private def stripColumnPrefix(token: String, column: String): String = {
    val prefix = column + Separator
    if (token.startsWith(prefix)) token.substring(prefix.length)
    else if (StringUtils.isBlank(token)) ""
    else token
  }

  /**
   * Re-keys one linking entry for the next propagation round.
   *
   * An entry whose count is 1 is emitted unchanged. Otherwise the entry's key is added to its
   * key set (in place) and the entry is emitted once under every id of its id set; all emitted
   * tuples share the same set instances.
   *
   * @param row (key, (id set, key set, count))
   * @return the entries to feed into the next `reduceByKey`
   */
  def flatIdSet(row: (String, (mutable.Set[String], mutable.Set[String], Int))): Array[(String, (mutable.Set[String], mutable.Set[String], Int))] = {
    val (key, (ids, keys, count)) = row
    count match {
      case 1 =>
        Array((key, (ids, keys, count)))
      case _ =>
        keys += key // add key to keySet
        ids.map(d => (d, (ids, keys, count))).toArray
    }
  }

  /**
   * Reducer for linking entries: unions the id sets and the key sets and sums the counts.
   *
   * Both sets of `t1` are updated in place and returned; `t2` is only read.
   */
  def tuple3Add(t1: (mutable.Set[String], mutable.Set[String], Int),
                t2: (mutable.Set[String], mutable.Set[String], Int)): (mutable.Set[String], mutable.Set[String], Int) = {
    t1._1 ++= t2._1
    t1._2 ++= t2._2
    val added = t1._3 + t2._3
    (t1._1, t1._2, added)
  }

  /**
   * Cartesian product over per-column value sets.
   *
   * Each result row has one cell per input set, in input order; an empty set contributes the
   * cell `""`. For rows with at least two cells, the first cell is then replaced by the
   * lower-case hex MD5 of the second cell (hashed as UTF-8, so the id does not depend on the
   * JVM default charset). Rows with a single cell are returned as they are.
   *
   * NOTE(review): every row whose second cell is `""` receives the same id (the MD5 of the
   * empty string) — confirm that the second column is always populated.
   *
   * @param arrs one set of tokens per output column
   * @return all combinations; empty when `arrs` is empty
   */
  def cartesion_product(arrs: mutable.ListBuffer[mutable.Set[String]]): List[mutable.ListBuffer[String]] = {
    val rows = arrs.foldLeft(List.empty[mutable.ListBuffer[String]]) { (cumArr, addArr) =>
      // cumArr is only empty while the first column is being processed.
      if (cumArr.isEmpty && addArr.isEmpty) List(mutable.ListBuffer[String](""))
      else if (cumArr.isEmpty) addArr.map(t2 => mutable.ListBuffer(t2)).toList
      else if (addArr.isEmpty) cumArr.map(_ :+ "")
      else cumArr.flatMap(t1 => addArr.map(t2 => t1 :+ t2))
    }

    val md5 = MessageDigest.getInstance("MD5") // digest() resets the instance after each call
    rows.foreach { row =>
      if (row.size > 1) {
        val bytes = row(1).getBytes(java.nio.charset.StandardCharsets.UTF_8)
        row(0) = md5.digest(bytes).map("%02x".format(_)).mkString
      }
    }
    rows
  }
}
